#!/usr/bin/env python3
# coding: utf-8
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""operator dsl function: batchmatmul"""
from functools import reduce
import akg.topi
import akg.tvm
from akg.tvm.hybrid import script
from akg.ops.math import cast
from akg.utils import custom_tiling as ct_util
from akg.utils import validation_check as vc_util
from akg.utils.format_transform import get_shape, get_bytes
from akg.utils.math import greatest_common_divisor, least_common_multiple
from akg.utils.kernel_exec import product_is_mini
from akg.utils import dynamic_shape as ds

batchmatmul_set_dim_map = {
    # 2D
    str((256, 1024, 4096, "float32", False, True)): ((16, 16), (16, 16), (16, 16)),
    str((160, 1024, 1024, "float32", False, False)): ((1, 1), (16, 16), (1024, 1024)),
    str((8192, 1024, 4096, "float32", False, False)): ((1, 1), (1024, 1024), (16, 16)),
    str((1024, 1024, 8192, "float32", True, False)): ((8, 8), (8, 8), (512, 512)),
    str((1024, 1024, 2, "float32", False, False)): ((64, 64), (64, 64), (2, 2)),
    str((1024, 1024, 4096, "float32", False, False)): ((1, 1), (16, 16), (512, 512)),
    str((2, 1024, 8192, "float32", True, False)): ((2, 2), (16, 16), (512, 512)),
    str((30522, 1024, 1280, "float32", True, False)): ((3, 3), (64, 64), (128, 128)),
    str((1024, 4096, 8192, "float32", True, False)): ((32, 32), (32, 32), (32, 32)),
    str((2, 1024, 64, "float32", True, False)): ((4, 4), (64, 64), (64, 64)),
    str((160, 30522, 1024, "float32", False, True)): ((1, 1), (3, 3), (512, 512)),
    str((1024, 1024, 64, "float32", True, False)): ((16, 16), (16, 16), (64, 64)),
    str((4096, 1024, 8192, "float32", True, False)): ((16, 16), (64, 64), (16, 16)),
    str((1280, 1024, 30522, "float32", False, False)): ((1, 1), (512, 512), (3, 3)),
    str((8192, 1024, 4096, "float32", False, True)): ((1, 1), (16, 16), (1024, 1024)),
    str((1280, 30522, 1024, "float32", False, True)): ((1, 1), (3, 3), (512, 512)),
    str((8192, 4096, 1024, "float32", False, False)): ((1, 1), (64, 64), (256, 256)),
    str((768, 768, 8192, "float16", True, False)): ((16, 16), (16, 16), (64, 64)),
    str((3072, 768, 8192, "float16", True, False)): ((16, 16), (16, 16), (64, 64)),
    str((2, 768, 64, "float32", True, False)): ((2, 2), (64, 64), (64, 64)),
    str((768, 1024, 8192, "float16", False, False)): ((16, 1), (16, 1), (64, 1)),
    str((33, 64, 16384, "float32", True, False)): ((1, 1), (16, 16), (128, 128)),
    str((768, 768, 64, "float32", True, False)): ((16, 16), (16, 16), (16, 16)),
    str((8192, 768, 21128, "float32", False, False)): ((4, 4), (128, 128), (4, 4)),
    str((2, 768, 8192, "float32", True, False)): ((2, 2), (16, 16), (512, 512)),
    str((8192, 768, 768, "float16", False, True)): ((1, 1), (16, 16), (768, 768)),
    str((21128, 768, 8192, "float32", True, False)): ((4, 4), (128, 128), (64, 64)),
    str((768, 1024, 768, "float32", True, False)): ((2, 2), (128, 128), (128, 128)),
    str((16384, 16384, 33, "float32", True, False)): ((16, 16), (64, 64), (33, 33)),
    str((21128, 768, 8192, "float32", False, False)): ((4, 4), (128, 128), (64, 64)),
    str((1280, 1280, 1024, "float32", False, True)): ((4, 4), (32, 32), (128, 128)),
    str((1280, 768, 21128, "float32", False, False)): ((1, 1), (768, 768), (8, 8)),
    str((8192, 768, 768, "float32", False, False)): ((1, 1), (8, 8), (768, 768)),
    str((20, 768, 32000, "float32", False, False)): ((20, 20), (48, 48), (32, 32)),
    str((21128, 768, 1280, "float32", True, False)): ((2, 2), (32, 32), (32, 32)),
    str((768, 3072, 1892, "float32", True, False)): ((16, 16), (16, 16), (16, 16)),
    str((33, 64, 16384, "float32", False, True)): ((2, 2), (32, 32), (32, 32)),
    str((8192, 3072, 768, "float32", False, True)): ((16, 16), (16, 16), (16, 16)),
    str((2, 8192, 768, "float32", True, False)): ((16, 16), (16, 16), (16, 16)),
    str((8192, 768, 3072, "float32", False, False)): ((32, 32), (32, 32), (32, 32)),
    str((8192, 768, 3072, "float16", False, True)): ((1, 1), (16, 1), (768, 1)),
    str((8192, 3072, 768, "float16", False, True)): ((1, 1), (16, 1), (768, 1)),
    str((8192, 768, 768, "float16", False, False)): ((1, 1), (16, 1), (768, 1)),
    str((768, 3072, 768, "float16", False, False)): ((1, 1), (16, 1), (768, 1)),
    str((8192, 3072, 768, "float16", False, False)): ((1, 1), (16, 1), (768, 1)),
    str((8192, 768, 3072, "float16", False, False)): ((1, 1), (16, 1), (768, 1)),
    str((21128, 768, 5120, "float32", True, False)): ((4, 4), (128, 128), (64, 64)),
    str((21128, 768, 2560, "float32", True, False)): ((4, 4), (128, 128), (64, 64)),
    str((1024, 2, 4, "float32", True, False)): ((512, 1), (2, 1), (4, 1)),
    str((21128, 1024, 21128, "float32", False, False)): ((4, 4), (512, 512), (4, 4)),
    str((320, 768, 21128, "float32", False, False)): ((40, 40), (128, 128), (4, 4)),
    str((5120, 1024, 21128, "float32", False, False)): ((64, 64), (128, 128), (4, 4)),
    str((16384, 4096, 1024, "float32", False, False)): ((4, 4), (64, 64), (64, 64)),
    str((1024, 4096, 16384, "float32", True, False)): ((64, 64), (64, 64), (4, 4)),
    str((768, 3072, 8192, "float16", True, False)): ((1, 1), (16, 1), (512, 1)),
    str((2048, 768, 3072, "float32", False, False)): ((8, 1), (16, 1), (128, 1)),
    str((1024, 2, 128, "float32", True, False)): ((32, 1), (2, 1), (128, 1)),
    str((1024, 2, 16, "float32", True, False)): ((256, 1), (2, 1), (16, 1)),
    str((1024, 2, 32, "float32", True, False)): ((128, 1), (2, 1), (32, 1)),
    str((1024, 2, 64, "float32", True, False)): ((64, 1), (2, 1), (64, 1)),
    str((1024, 2, 8, "float32", True, False)): ((512, 1), (2, 1), (8, 1)),
    str((768, 2, 128, "float32", True, False)): ((32, 1), (2, 1), (128, 1)),
    str((768, 2, 16, "float32", True, False)): ((256, 1), (2, 1), (16, 1)),
    str((768, 2, 32, "float32", True, False)): ((128, 1), (2, 1), (32, 1)),
    str((768, 2, 64, "float32", True, False)): ((64, 1), (2, 1), (64, 1)),
    str((768, 3072, 2048, "float32", True, False)): ((16, 1), (16, 1), (64, 1)),
    str((3072, 768, 2048, "float32", True, False)): ((16, 1), (16, 1), (64, 1)),
    str((65536, 1024, 4096, "float32", False, False)): ((1, 1), (1024, 1), (16, 1)),
    str((10240, 1024, 21128, "float32", False, False)): ((1, 1), (1024, 1), (16, 1)),
    str((10240, 21128, 1024, "float32", False, False)): ((1, 1), (1024, 1), (16, 1)),
    str((21128, 768, 21128, "float32", True, False)): ((1, 1), (768, 1), (16, 1)),
    str((2560, 21128, 1024, "float32", False, False)): ((1, 1), (1024, 1), (16, 1)),
    str((5120, 21128, 1024, "float32", False, False)): ((1, 1), (1024, 1), (16, 1)),
    str((32768, 4096, 1024, "float32", False, False)): ((1, 1), (1024, 1), (16, 1)),
    str((20480, 1024, 21128, "float32", False, False)): ((1, 1), (1024, 1), (16, 1)),
    str((128, 1024, 4096, "float32", False, True)): ((1, 1), (16, 1), (16, 1)),
    str((128, 1024, 4096, "float16", False, True)): ((1, 1), (16, 1), (32, 1)),
    str((128, 4096, 1024, "float16", False, True)): ((1, 1), (32, 1), (32, 1)),
    str((20480, 21128, 1024, "float32", False, False)): ((1, 1), (1024, 1), (16, 1)),
    str((65536, 4096, 1024, "float32", False, False)): ((1, 1), (1024, 1), (16, 1)),
    str((512, 1024, 4096, "float32", False, True)): ((1, 1), (8, 1), (1024, 1)),
    str((256, 1024, 4096, "float32", False, True)): ((1, 1), (8, 1), (1024, 1)),
    str((1024, 1024, 4096, "float32", False, True)): ((1, 1), (8, 1), (1024, 1)),
    str((2048, 1024, 4096, "float32", False, True)): ((1, 1), (8, 1), (1024, 1)),
    str((2, 2, 1024, "float32", False, True)): ((1, 1), (2, 1), (1024, 1)),
    str((65536, 768, 3072, "float32", False, True)): ((1, 1), (192, 1), (96, 1)),
    str((16384, 1024, 4096, "float32", False, True)): ((1, 1), (128, 1), (128, 1)),
    str((1024, 4096, 32768, "float32", True, False)): ((8, 1), (1024, 1), (4, 1)),
    str((4, 2, 1024, "float32", False, True)): ((1, 1), (2, 1), (1024, 1)),
    str((2, 2, 768, "float32", False, True)): ((1, 1), (2, 1), (768, 1)),
    str((8, 2, 768, "float32", False, True)): ((1, 1), (2, 1), (768, 1)),
    str((4, 2, 768, "float32", False, True)): ((1, 1), (2, 1), (768, 1)),
    # lenet5
    str((32, 10, 84, 'float16', False, True)): ((1, 1), (16, 1), (84, 1)),
    # alexnet
    str((32, 4096, 9216, 'float32', False, False)): ((1, 1), (16, 1), (1024, 1)),
    str((32, 10, 4096, 'float32', False, False)): ((1, 1), (16, 1), (1024, 1)),

    # 3D
    str((128, 128, 64, 1536, "float32", True, False)): ((4, 4), (8, 8), (16, 16), (32, 32)),
    str((128, 768, 128, 64, "float32", False, True)): ((4, 4), (8, 8), (16, 16), (64, 64)),
    str((128, 128, 64, 6144, "float32", True, False)): ((4, 4), (8, 8), (16, 16), (32, 32)),
    str((128, 128, 64, 16384, "float32", True, False)): ((1, 1), (128, 128), (64, 64), (4, 4)),
    str((128, 128, 64, 2048, "float32", True, False)): ((1, 1), (128, 128), (64, 64), (4, 4)),
    str((128, 128, 64, 4096, "float32", True, False)): ((1, 1), (128, 128), (64, 64), (4, 4)),
    str((128, 128, 64, 8192, "float32", True, False)): ((1, 1), (128, 128), (64, 64), (4, 4)),
    str((128, 128, 64, 4096, "float32", True, False)): ((1, 1), (8, 1), (64, 1), (64, 1)),
    str((128, 128, 64, 12288, "float32", True, False)): ((1, 1), (8, 1), (64, 1), (32, 1)),

    # 4D
    str((64, 12, 128, 64, 128, "float32", False, False)): ((1, 1), (12, 12), (8, 8), (8, 8), (32, 32)),

    str((1, 768, 2, "float32", False, False)): ((768, 1), (1, 1)),
    str((20, 768, 21128, "float32", False, False)): ((1, 1), (768, 1), (16, 1)),
    str((128, 12, 64, 128, "float32", False, False)): ((1, 1), (1, 1), (64, 1), (128, 1)),
    str((128, 3072, 768, "float32", False, False)): ((1, 1), (768, 1), (32, 1)),
    str((1, 12, 128, 128, 64, "float32", False, True)): ((1, 1), (2, 1), (128, 1), (64, 1)),
    str((128, 128, 64, 12, "float32", True, False)): ((1, 1), (4, 1), (64, 1), (12, 1)),
    str((1, 12, 128, 64, 128, "float32", False, False)): ((1, 1), (1, 1), (64, 1), (32, 1)),
    str((20, 21128, 768, "float32", False, True)): ((1, 1), (16, 1), (768, 1)),
    str((1, 12, 128, 64, 128, "float32", True, False)): ((1, 1), (8, 1), (64, 1), (32, 1)),
    str((20, 768, 768, "float32", False, False)): ((1, 1), (768, 1), (16, 1)),
    str((128, 768, 3072, "float32", False, False)): ((1, 1), (768, 1), (16, 1)),
    str((1, 768, 768, "float32", False, False)): ((768, 1), (16, 1)),
    str((21128, 768, 20, "float32", True, False)): ((8, 1), (768, 1), (8, 1)),
    str((128, 12, 128, 64, "float32", False, True)): ((1, 1), (4, 1), (128, 1), (64, 1)),
    str((128, 768, 768, "float32", False, False)): ((1, 1), (768, 1), (16, 1)),
    str((40, 768, 768, "float32", False, False)): ((2, 1), (768, 1), (16, 1)),
    str((2, 12, 128, 128, 64, "float32", False, True)): ((1, 1), (1, 1), (2, 1), (128, 1), (64, 1)),
    str((2, 768, 2, "float32", False, False)): ((2, 1), (768, 1), (2, 1)),
    str((40, 21128, 768, "float32", False, True)): ((1, 1), (16, 1), (768, 1)),
    str((256, 768, 768, "float32", False, False)): ((2, 1), (768, 1), (8, 1)),
    str((2, 12, 128, 64, 128, "float32", True, False)): ((1, 1), (1, 1), (8, 1), (64, 1), (32, 1)),
    str((2, 12, 128, 64, 128, "float32", False, False)): ((1, 1), (1, 1), (1, 1), (64, 1), (32, 1)),
    str((128, 128, 64, 24, "float32", True, False)): ((1, 1), (2, 1), (64, 1), (12, 1)),
    str((256, 3072, 768, "float32", False, False)): ((1, 1), (768, 1), (32, 1)),
    str((128, 24, 128, 64, "float32", False, True)): ((1, 1), (4, 1), (64, 1), (64, 1)),
    str((2, 768, 768, "float32", False, False)): ((1, 1), (768, 1), (16, 1)),
    str((21128, 768, 40, "float32", True, False)): ((8, 1), (768, 1), (8, 1)),
    str((40, 768, 21128, "float32", False, False)): ((2, 1), (768, 1), (16, 1)),
    str((256, 768, 3072, "float32", False, False)): ((1, 1), (768, 1), (16, 1)),
    str((128, 24, 64, 128, "float32", False, False)): ((1, 1), (1, 1), (64, 1), (128, 1)),
    str((128, 48, 128, 64, "float32", False, True)): ((1, 1), (4, 1), (64, 1), (64, 1)),
    str((512, 3072, 768, "float32", False, False)): ((1, 1), (768, 1), (32, 1)),
    str((128, 48, 64, 128, "float32", False, False)): ((1, 1), (1, 1), (64, 1), (128, 1)),
    str((512, 768, 768, "float32", False, False)): ((2, 1), (768, 1), (8, 1)),
    str((4, 12, 128, 64, 128, "float32", True, False)): ((1, 1), (1, 1), (8, 1), (64, 1), (32, 1)),
    str((512, 768, 3072, "float32", False, False)): ((2, 1), (768, 1), (8, 1)),
    str((4, 768, 2, "float32", False, False)): ((1, 1), (768, 1), (2, 1)),
    str((80, 768, 768, "float32", False, False)): ((1, 1), (768, 1), (16, 1)),
    str((4, 12, 128, 64, 128, "float32", False, False)): ((1, 1), (1, 1), (1, 1), (64, 1), (32, 1)),
    str((21128, 768, 80, "float32", True, False)): ((8, 1), (768, 1), (8, 1)),
    str((80, 768, 21128, "float32", False, False)): ((2, 1), (768, 1), (8, 1)),
    str((128, 128, 64, 48, "float32", True, False)): ((1, 1), (4, 1), (64, 1), (12, 1)),
    str((4, 768, 768, "float32", False, False)): ((1, 1), (768, 1), (16, 1)),
    str((4, 12, 128, 128, 64, "float32", False, True)): ((1, 1), (1, 1), (2, 1), (128, 1), (64, 1)),
    str((80, 21128, 768, "float32", False, True)): ((1, 1), (16, 1), (768, 1)),
    str((128, 96, 64, 128, "float32", False, False)): ((1, 1), (1, 1), (64, 1), (128, 1)),
    str((1024, 3072, 768, "float32", False, False)): ((1, 1), (768, 1), (32, 1)),
    str((160, 21128, 768, "float32", False, True)): ((1, 1), (16, 1), (768, 1)),
    str((8, 12, 128, 64, 128, "float32", False, False)): ((1, 1), (1, 1), (1, 1), (64, 1), (32, 1)),
    str((8, 768, 2, "float32", False, False)): ((8, 1), (768, 1), (2, 1)),
    str((8, 768, 768, "float32", False, False)): ((1, 1), (768, 1), (16, 1)),
    str((1024, 768, 768, "float32", False, False)): ((2, 1), (768, 1), (8, 1)),
    str((128, 96, 128, 64, "float32", False, True)): ((1, 1), (4, 1), (64, 1), (64, 1)),
    str((128, 128, 64, 96, "float32", True, False)): ((1, 1), (4, 1), (64, 1), (12, 1)),
    str((8, 12, 128, 128, 64, "float32", False, True)): ((1, 1), (1, 1), (2, 1), (128, 1), (64, 1)),
    str((1024, 768, 3072, "float32", False, False)): ((2, 1), (768, 1), (8, 1)),
    str((21128, 768, 160, "float32", True, False)): ((8, 1), (768, 1), (8, 1)),
    str((160, 768, 768, "float32", False, False)): ((1, 1), (768, 1), (16, 1)),
    str((8, 12, 128, 64, 128, "float32", True, False)): ((1, 1), (1, 1), (8, 1), (64, 1), (32, 1)),
    str((160, 768, 21128, "float32", False, False)): ((2, 1), (768, 1), (8, 1)),
    str((16, 768, 2, "float32", False, False)): ((2, 1), (768, 1), (2, 1)),
    str((320, 21128, 768, "float32", False, True)): ((1, 1), (16, 1), (768, 1)),
    str((16, 12, 128, 64, 128, "float32", True, False)): ((1, 1), (1, 1), (8, 1), (64, 1), (32, 1)),
    str((2048, 3072, 768, "float32", False, False)): ((1, 1), (768, 1), (32, 1)),
    str((16, 12, 128, 64, 128, "float32", False, False)): ((1, 1), (1, 1), (1, 1), (64, 1), (32, 1)),
    str((128, 128, 64, 192, "float32", True, False)): ((1, 1), (4, 1), (64, 1), (12, 1)),
    str((128, 192, 64, 128, "float32", False, False)): ((1, 1), (1, 1), (64, 1), (128, 1)),
    str((21128, 768, 320, "float32", True, False)): ((8, 1), (768, 1), (8, 1)),
    str((128, 192, 128, 64, "float32", False, True)): ((1, 1), (4, 1), (64, 1), (64, 1)),
    str((320, 768, 21128, "float32", False, False)): ((2, 1), (768, 1), (8, 1)),
    str((16, 12, 128, 128, 64, "float32", False, True)): ((1, 1), (1, 1), (2, 1), (128, 1), (64, 1)),
    str((320, 768, 768, "float32", False, False)): ((1, 1), (768, 1), (16, 1)),
    str((2048, 768, 3072, "float32", False, False)): ((2, 1), (768, 1), (8, 1)),
    str((2048, 768, 768, "float32", False, False)): ((2, 1), (768, 1), (8, 1)),
    str((16, 768, 768, "float32", False, False)): ((1, 1), (768, 1), (16, 1)),
    str((4096, 768, 768, "float32", False, False)): ((2, 1), (768, 1), (8, 1)),
    str((128, 128, 64, 384, "float32", True, False)): ((1, 1), (4, 1), (64, 1), (12, 1)),
    str((4096, 3072, 768, "float32", False, False)): ((1, 1), (768, 1), (32, 1)),
    str((640, 768, 768, "float32", False, False)): ((1, 1), (768, 1), (16, 1)),
    str((128, 384, 128, 64, "float32", False, True)): ((1, 1), (4, 1), (64, 1), (64, 1)),
    str((128, 384, 64, 128, "float32", False, False)): ((1, 1), (1, 1), (64, 1), (128, 1)),
    str((32, 12, 128, 128, 64, "float32", False, True)): ((1, 1), (1, 1), (2, 1), (128, 1), (64, 1)),
    str((32, 12, 128, 64, 128, "float32", True, False)): ((1, 1), (1, 1), (8, 1), (64, 1), (32, 1)),
    str((32, 12, 128, 64, 128, "float32", False, False)): ((1, 1), (1, 1), (1, 1), (64, 1), (32, 1)),
    str((32, 768, 768, "float32", False, False)): ((1, 1), (768, 1), (16, 1)),
    str((4096, 768, 3072, "float32", False, False)): ((2, 1), (768, 1), (8, 1)),
    str((640, 768, 21128, "float32", False, False)): ((2, 1), (768, 1), (8, 1)),
    str((32, 768, 2, "float32", False, False)): ((4, 1), (768, 1), (2, 1)),
    str((21128, 768, 640, "float32", True, False)): ((8, 1), (768, 1), (8, 1)),
    str((640, 21128, 768, "float32", False, True)): ((1, 1), (16, 1), (768, 1)),
    str((64, 12, 128, 64, 128, "float32", False, False)): ((1, 1), (1, 1), (1, 1), (64, 1), (32, 1)),
    str((8192, 768, 3072, "float32", False, False)): ((2, 1), (768, 1), (8, 1)),
    str((1280, 768, 768, "float32", False, False)): ((1, 1), (768, 1), (16, 1)),
    str((8192, 3072, 768, "float32", False, False)): ((1, 1), (768, 1), (32, 1)),
    str((128, 768, 128, 64, "float32", False, True)): ((1, 1), (4, 1), (64, 1), (64, 1)),
    str((128, 128, 64, 768, "float32", True, False)): ((1, 1), (4, 1), (64, 1), (12, 1)),
    str((21128, 768, 1280, "float32", True, False)): ((8, 1), (768, 1), (8, 1)),
    str((1280, 768, 21128, "float32", False, False)): ((2, 1), (768, 1), (8, 1)),
    str((64, 12, 128, 64, 128, "float32", True, False)): ((1, 1), (1, 1), (8, 1), (64, 1), (32, 1)),
    str((1280, 21128, 768, "float32", False, True)): ((1, 1), (16, 1), (768, 1)),
    str((64, 768, 768, "float32", False, False)): ((2, 1), (768, 1), (16, 1)),
    str((64, 12, 128, 128, 64, "float32", False, True)): ((1, 1), (1, 1), (2, 1), (128, 1), (64, 1)),
    str((64, 768, 2, "float32", False, False)): ((8, 1), (768, 1), (2, 1)),
    str((8192, 768, 768, "float32", False, False)): ((2, 1), (768, 1), (8, 1)),
    str((128, 768, 64, 128, "float32", False, False)): ((1, 1), (1, 1), (64, 1), (128, 1)),
    str((128, 12, 128, 128, 64, "float32", False, True)): ((1, 1), (1, 1), (2, 1), (128, 1), (64, 1)),
    str((16384, 768, 3072, "float32", False, False)): ((2, 1), (768, 1), (8, 1)),
    str((16384, 768, 768, "float32", False, False)): ((2, 1), (768, 1), (8, 1)),
    str((128, 1536, 128, 64, "float32", False, True)): ((1, 1), (4, 1), (64, 1), (64, 1)),
    str((128, 768, 2, "float32", False, False)): ((16, 1), (768, 1), (2, 1)),
    str((2560, 768, 768, "float32", False, False)): ((1, 1), (768, 1), (16, 1)),
    str((21128, 768, 2560, "float32", True, False)): ((8, 1), (768, 1), (8, 1)),
    str((128, 12, 128, 64, 128, "float32", False, False)): ((1, 1), (1, 1), (1, 1), (64, 1), (32, 1)),
    str((2560, 768, 21128, "float32", False, False)): ((2, 1), (768, 1), (8, 1)),
    str((128, 128, 64, 1536, "float32", True, False)): ((1, 1), (4, 1), (64, 1), (12, 1)),
    str((128, 1536, 64, 128, "float32", False, False)): ((1, 1), (1, 1), (64, 1), (128, 1)),
    str((2560, 21128, 768, "float32", False, True)): ((1, 1), (16, 1), (768, 1)),
    str((128, 12, 128, 64, 128, "float32", True, False)): ((1, 1), (1, 1), (8, 1), (64, 1), (32, 1)),
    str((16384, 3072, 768, "float32", False, False)): ((1, 1), (768, 1), (32, 1)),
    str((1, 16, 128, 64, 128, "float32", False, False)): ((1, 1), (4, 1), (64, 1), (8, 1)),
    str((1, 16, 128, 64, 128, "float32", True, False)): ((1, 1), (8, 1), (64, 1), (4, 1)),
    str((128, 16, 128, 64, "float32", False, True)): ((1, 1), (2, 1), (128, 1), (64, 1)),
    str((128, 1024, 1024, "float32", False, False)): ((2, 1), (1024, 1), (8, 1)),
    str((20, 1024, 21128, "float32", False, False)): ((1, 1), (1024, 1), (8, 1)),
    str((128, 4096, 1024, "float32", False, False)): ((2, 1), (1024, 1), (8, 1)),
    str((128, 1024, 4096, "float32", False, False)): ((2, 1), (1024, 1), (8, 1)),
    str((128, 16, 64, 128, "float32", False, False)): ((1, 1), (1, 1), (64, 1), (32, 1)),
    str((20, 21128, 1024, "float32", False, True)): ((1, 1), (16, 1), (1024, 1)),
    str((1, 16, 128, 128, 64, "float32", False, True)): ((1, 1), (4, 1), (64, 1), (64, 1)),
    str((1, 1024, 2, "float32", False, False)): ((1024, 1), (2, 1)),
    str((21128, 1024, 20, "float32", True, False)): ((8, 1), (1024, 1), (2, 1)),
    str((20, 1024, 1024, "float32", False, False)): ((1, 1), (1024, 1), (8, 1)),
    str((128, 128, 64, 16, "float32", True, False)): ((1, 1), (4, 1), (64, 1), (8, 1)),
    str((1, 1024, 1024, "float32", False, False)): ((1024, 1), (8, 1)),
    str((128, 1024, 1024, "float32", False, True)): ((2, 1), (8, 1), (1024, 1)),
    str((1, 1024, 1024, "float32", False, True)): ((8, 1), (1024, 1)),
    str((20, 1024, 1024, "float32", False, True)): ((1, 1), (8, 1), (1024, 1)),
    str((128, 1024, 4096, "float32", False, True)): ((1, 1), (8, 1), (2048, 1)),
    str((128, 4096, 1024, "float32", False, True)): ((2, 1), (8, 1), (1024, 1)),
    str((1, 2, 1024, "float32", False, True)): ((2, 1), (1024, 1)),
    str((2, 1024, 1, "float32", True, False)): ((2, 1), (1024, 1)),
    str((1024, 1024, 128, "float32", True, False)): ((8, 1), (1024, 1), (2, 1)),
    str((1024, 1024, 1, "float32", True, False)): ((16, 1), (1024, 1)),
    str((1024, 1024, 20, "float32", True, False)): ((8, 1), (1024, 1), (2, 1)),
    str((4096, 1024, 128, "float32", True, False)): ((8, 1), (1024, 1), (2, 1)),
    str((1024, 4096, 128, "float32", True, False)): ((4, 1), (2048, 1), (2, 1)),
    str((128, 128, 64, 32, "float32", True, False)): ((1, 1), (2, 1), (64, 1), (16, 1)),
    str((128, 32, 128, 64, "float32", False, True)): ((1, 1), (2, 1), (128, 1), (64, 1)),
    str((256, 4096, 1024, "float32", False, False)): ((4, 1), (2048, 1), (2, 1)),
    str((256, 1024, 1024, "float32", False, False)): ((4, 1), (1024, 1), (4, 1)),
    str((40, 21128, 1024, "float32", False, True)): ((1, 1), (16, 1), (1024, 1)),
    str((2, 16, 128, 64, 128, "float32", False, False)): ((1, 1), (1, 1), (2, 1), (64, 1), (16, 1)),
    str((2, 16, 128, 64, 128, "float32", True, False)): ((1, 1), (1, 1), (4, 1), (64, 1), (8, 1)),
    str((21128, 1024, 40, "float32", True, False)): ((8, 1), (1024, 1), (2, 1)),
    str((2, 1024, 2, "float32", False, False)): ((2, 1), (1024, 1), (2, 1)),
    str((256, 1024, 4096, "float32", False, False)): ((2, 1), (1024, 1), (8, 1)),
    str((40, 1024, 21128, "float32", False, False)): ((2, 1), (1024, 1), (8, 1)),
    str((2, 1024, 1024, "float32", False, False)): ((1, 1), (1024, 1), (8, 1)),
    str((40, 1024, 1024, "float32", False, False)): ((2, 1), (1024, 1), (8, 1)),
    str((128, 32, 64, 128, "float32", False, False)): ((1, 1), (1, 1), (64, 1), (32, 1)),
    str((2, 16, 128, 128, 64, "float32", False, True)): ((1, 1), (1, 1), (4, 1), (64, 1), (64, 1)),
    str((2, 1024, 1024, "float32", False, True)): ((1, 1), (8, 1), (1024, 1)),
    str((256, 1024, 1024, "float32", False, True)): ((2, 1), (8, 1), (1024, 1)),
    str((256, 4096, 1024, "float32", False, True)): ((2, 1), (8, 1), (1024, 1)),
    str((2, 2, 1024, "float32", False, True)): ((1, 1), (2, 1), (1024, 1)),
    str((256, 1024, 4096, "float32", False, True)): ((1, 1), (8, 1), (2048, 1)),
    str((40, 1024, 1024, "float32", False, True)): ((2, 1), (8, 1), (1024, 1)),
    str((1024, 1024, 256, "float32", True, False)): ((8, 1), (1024, 1), (2, 1)),
    str((1024, 1024, 40, "float32", True, False)): ((8, 1), (1024, 1), (2, 1)),
    str((4096, 1024, 256, "float32", True, False)): ((8, 1), (1024, 1), (2, 1)),
    str((2, 1024, 2, "float32", True, False)): ((2, 1), (1024, 1), (2, 1)),
    str((1024, 1024, 2, "float32", True, False)): ((16, 1), (1024, 1), (1, 1)),
    str((1024, 4096, 256, "float32", True, False)): ((4, 1), (2048, 1), (2, 1)),
    str((4, 16, 128, 128, 64, "float32", False, True)): ((1, 1), (1, 1), (4, 1), (64, 1), (64, 1)),
    str((4, 16, 128, 64, 128, "float32", True, False)): ((1, 1), (1, 1), (4, 1), (64, 1), (8, 1)),
    str((80, 21128, 1024, "float32", False, True)): ((1, 1), (16, 1), (1024, 1)),
    str((4, 1024, 1024, "float32", False, False)): ((1, 1), (1024, 1), (8, 1)),
    str((128, 64, 128, 64, "float32", False, True)): ((1, 1), (2, 1), (128, 1), (64, 1)),
    str((128, 128, 64, 64, "float32", True, False)): ((1, 1), (2, 1), (64, 1), (16, 1)),
    str((21128, 1024, 80, "float32", True, False)): ((8, 1), (1024, 1), (2, 1)),
    str((4, 1024, 2, "float32", False, False)): ((4, 1), (1024, 1), (2, 1)),
    str((128, 64, 64, 128, "float32", False, False)): ((1, 1), (1, 1), (64, 1), (32, 1)),
    str((80, 1024, 21128, "float32", False, False)): ((1, 1), (1024, 1), (8, 1)),
    str((512, 1024, 1024, "float32", False, False)): ((4, 1), (1024, 1), (4, 1)),
    str((512, 1024, 4096, "float32", False, False)): ((2, 1), (1024, 1), (8, 1)),
    str((512, 4096, 1024, "float32", False, False)): ((4, 1), (2048, 1), (2, 1)),
    str((4, 16, 128, 64, 128, "float32", False, False)): ((1, 1), (1, 1), (2, 1), (64, 1), (16, 1)),
    str((80, 1024, 1024, "float32", False, False)): ((1, 1), (1024, 1), (8, 1)),
    str((512, 4096, 1024, "float32", False, True)): ((2, 1), (8, 1), (1024, 1)),
    str((80, 1024, 1024, "float32", False, True)): ((1, 1), (8, 1), (1024, 1)),
    str((4, 1024, 1024, "float32", False, True)): ((1, 1), (8, 1), (1024, 1)),
    str((4, 2, 1024, "float32", False, True)): ((1, 1), (2, 1), (1024, 1)),
    str((512, 1024, 1024, "float32", False, True)): ((2, 1), (8, 1), (1024, 1)),
    str((512, 1024, 4096, "float32", False, True)): ((1, 1), (8, 1), (2048, 1)),
    str((1024, 1024, 512, "float32", True, False)): ((4, 1), (2048, 1), (2, 1)),
    str((2, 1024, 4, "float32", True, False)): ((2, 1), (1024, 1), (4, 1)),
    str((4096, 1024, 512, "float32", True, False)): ((8, 1), (1024, 1), (2, 1)),
    str((1024, 1024, 4, "float32", True, False)): ((16, 1), (1024, 1), (1, 1)),
    str((1024, 1024, 80, "float32", True, False)): ((8, 1), (1024, 1), (2, 1)),
    str((1024, 4096, 512, "float32", True, False)): ((4, 1), (2048, 1), (2, 1)),
    str((4096, 768, 3072, "float32", False, True)): ((1, 1), (8, 1), (3072, 1)),
    str((4096, 3072, 768, "float32", False, True)): ((1, 1), (16, 1), (768, 1)),
    str((4096, 768, 768, "float32", False, True)): ((1, 1), (16, 1), (768, 1)),
    str((8192, 768, 768, "float32", False, True)): ((1, 1), (16, 1), (768, 1)),
    str((8192, 3072, 768, "float32", False, True)): ((1, 1), (16, 1), (768, 1)),
    str((8192, 768, 3072, "float32", False, True)): ((1, 1), (8, 1), (3072, 1)),
    str((16384, 3072, 768, "float32", False, True)): ((1, 1), (16, 1), (768, 1)),
    str((16384, 768, 768, "float32", False, True)): ((1, 1), (16, 1), (768, 1)),
    str((16384, 768, 3072, "float32", False, True)): ((1, 1), (8, 1), (3072, 1)),
    str((3072, 768, 128, "float32", True, False)): ((4, 1), (768, 1), (4, 1)),
    str((768, 768, 1, "float32", True, False)): ((16, 1), (768, 1)),
    str((768, 768, 20, "float32", True, False)): ((8, 1), (768, 1), (2, 1)),
    str((768, 3072, 128, "float32", True, False)): ((4, 1), (3072, 1), (1, 1)),
    str((2, 768, 1, "float32", True, False)): ((2, 1), (768, 1)),
    str((768, 768, 128, "float32", True, False)): ((8, 1), (768, 1), (2, 1)),
    str((768, 768, 2, "float32", True, False)): ((4, 1), (768, 1), (2, 1)),
    str((768, 768, 40, "float32", True, False)): ((8, 1), (768, 1), (2, 1)),
    str((3072, 768, 256, "float32", True, False)): ((4, 1), (768, 1), (4, 1)),
    str((768, 3072, 256, "float32", True, False)): ((2, 1), (3072, 1), (2, 1)),
    str((768, 768, 256, "float32", True, False)): ((8, 1), (768, 1), (2, 1)),
    str((2, 768, 2, "float32", True, False)): ((2, 1), (768, 1), (2, 1)),
    str((768, 768, 80, "float32", True, False)): ((8, 1), (768, 1), (2, 1)),
    str((3072, 768, 512, "float32", True, False)): ((2, 1), (768, 1), (8, 1)),
    str((768, 3072, 512, "float32", True, False)): ((2, 1), (3072, 1), (2, 1)),
    str((768, 768, 4, "float32", True, False)): ((4, 1), (768, 1), (2, 1)),
    str((768, 768, 512, "float32", True, False)): ((8, 1), (768, 1), (2, 1)),
    str((2, 768, 4, "float32", True, False)): ((2, 1), (768, 1), (4, 1)),
    str((768, 768, 1024, "float32", True, False)): ((8, 1), (768, 1), (2, 1)),
    str((768, 3072, 1024, "float32", True, False)): ((2, 1), (3072, 1), (2, 1)),
    str((3072, 768, 1024, "float32", True, False)): ((2, 1), (768, 1), (8, 1)),
    str((768, 768, 8, "float32", True, False)): ((4, 1), (768, 1), (2, 1)),
    str((768, 768, 160, "float32", True, False)): ((8, 1), (768, 1), (2, 1)),
    str((2, 768, 8, "float32", True, False)): ((2, 1), (768, 1), (8, 1)),
    str((3072, 768, 2048, "float32", True, False)): ((2, 1), (768, 1), (8, 1)),
    str((768, 3072, 2048, "float32", True, False)): ((2, 1), (3072, 1), (2, 1)),
    str((768, 768, 2048, "float32", True, False)): ((8, 1), (768, 1), (2, 1)),
    str((768, 768, 320, "float32", True, False)): ((8, 1), (768, 1), (2, 1)),
    str((768, 768, 16, "float32", True, False)): ((4, 1), (768, 1), (2, 1)),
    str((2, 768, 16, "float32", True, False)): ((2, 1), (768, 1), (16, 1)),
    str((768, 768, 32, "float32", True, False)): ((4, 1), (768, 1), (2, 1)),
    str((2, 768, 32, "float32", True, False)): ((2, 1), (768, 1), (16, 1)),
    str((3072, 768, 4096, "float32", True, False)): ((2, 1), (768, 1), (8, 1)),
    str((768, 3072, 4096, "float32", True, False)): ((2, 1), (3072, 1), (2, 1)),
    str((768, 768, 640, "float32", True, False)): ((8, 1), (768, 1), (2, 1)),
    str((768, 768, 4096, "float32", True, False)): ((8, 1), (768, 1), (2, 1)),
    str((768, 768, 64, "float32", True, False)): ((4, 1), (768, 1), (2, 1)),
    str((2, 768, 64, "float32", True, False)): ((2, 1), (768, 1), (16, 1)),
    str((768, 3072, 8192, "float32", True, False)): ((2, 1), (3072, 1), (2, 1)),
    str((768, 768, 8192, "float32", True, False)): ((8, 1), (768, 1), (2, 1)),
    str((768, 768, 1280, "float32", True, False)): ((8, 1), (768, 1), (2, 1)),
    str((3072, 768, 8192, "float32", True, False)): ((2, 1), (768, 1), (8, 1)),
    str((768, 768, 2560, "float32", True, False)): ((8, 1), (768, 1), (2, 1)),
    str((3072, 768, 16384, "float32", True, False)): ((2, 1), (768, 1), (8, 1)),
    str((2, 768, 128, "float32", True, False)): ((2, 1), (768, 1), (16, 1)),
    str((768, 768, 16384, "float32", True, False)): ((8, 1), (768, 1), (2, 1)),
    str((768, 3072, 16384, "float32", True, False)): ((2, 1), (3072, 1), (2, 1)),
    str((8, 16, 128, 64, 128, "float32", False, False)): ((1, 1), (1, 1), (2, 1), (64, 1), (16, 1)),
    str((21128, 1024, 160, "float32", True, False)): ((4, 1), (1024, 1), (4, 1)),
    str((1024, 4096, 1024, "float32", False, False)): ((4, 1), (2048, 1), (2, 1)),
    str((160, 1024, 21128, "float32", False, False)): ((2, 1), (1024, 1), (4, 1)),
    str((8, 16, 128, 64, 128, "float32", True, False)): ((1, 1), (1, 1), (4, 1), (64, 1), (8, 1)),
    str((160, 1024, 1024, "float32", False, False)): ((1, 1), (1024, 1), (8, 1)),
    str((8, 1024, 1024, "float32", False, False)): ((1, 1), (1024, 1), (8, 1)),
    str((1024, 1024, 1024, "float32", False, False)): ((4, 1), (1024, 1), (4, 1)),
    str((160, 21128, 1024, "float32", False, True)): ((1, 1), (16, 1), (1024, 1)),
    str((1024, 1024, 4096, "float32", False, False)): ((2, 1), (1024, 1), (8, 1)),
    str((8, 16, 128, 128, 64, "float32", False, True)): ((1, 1), (1, 1), (4, 1), (64, 1), (64, 1)),
    str((128, 128, 64, 128, "float32", False, False)): ((1, 1), (1, 1), (64, 1), (32, 1)),
    str((128, 128, 128, 64, "float32", False, True)): ((1, 1), (2, 1), (128, 1), (64, 1)),
    str((128, 128, 64, 128, "float32", True, False)): ((1, 1), (2, 1), (64, 1), (16, 1)),
    str((8, 1024, 2, "float32", False, False)): ((1, 1), (1024, 1), (1, 1)),
    str((21128, 1024, 320, "float32", True, False)): ((4, 1), (1024, 1), (4, 1)),
    str((16, 1024, 1024, "float32", False, False)): ((1, 1), (1024, 1), (8, 1)),
    str((2048, 4096, 1024, "float32", False, False)): ((4, 1), (2048, 1), (2, 1)),
    str((16, 1024, 2, "float32", False, False)): ((1, 1), (1024, 1), (1, 1)),
    str((16, 16, 128, 64, 128, "float32", True, False)): ((1, 1), (1, 1), (4, 1), (64, 1), (8, 1)),
    str((16, 16, 128, 64, 128, "float32", False, False)): ((1, 1), (1, 1), (2, 1), (64, 1), (16, 1)),
    str((2048, 1024, 1024, "float32", False, False)): ((4, 1), (1024, 1), (4, 1)),
    str((320, 1024, 1024, "float32", False, False)): ((1, 1), (1024, 1), (8, 1)),
    str((2048, 1024, 4096, "float32", False, False)): ((2, 1), (1024, 1), (8, 1)),
    str((320, 21128, 1024, "float32", False, True)): ((1, 1), (16, 1), (1024, 1)),
    str((128, 256, 128, 64, "float32", False, True)): ((1, 1), (2, 1), (128, 1), (64, 1)),
    str((320, 1024, 21128, "float32", False, False)): ((2, 1), (1024, 1), (4, 1)),
    str((128, 256, 64, 128, "float32", False, False)): ((1, 1), (1, 1), (64, 1), (32, 1)),
    str((128, 128, 64, 256, "float32", True, False)): ((1, 1), (2, 1), (64, 1), (16, 1)),
    str((16, 16, 128, 128, 64, "float32", False, True)): ((1, 1), (1, 1), (4, 1), (64, 1), (64, 1)),
    str((1024, 4096, 1024, "float32", False, True)): ((2, 1), (8, 1), (1024, 1)),
    str((8, 2, 1024, "float32", False, True)): ((1, 1), (2, 1), (1024, 1)),
    str((160, 1024, 1024, "float32", False, True)): ((1, 1), (8, 1), (1024, 1)),
    str((1024, 1024, 4096, "float32", False, True)): ((1, 1), (8, 1), (2048, 1)),
    str((1024, 1024, 1024, "float32", False, True)): ((2, 1), (8, 1), (1024, 1)),
    str((8, 1024, 1024, "float32", False, True)): ((1, 1), (8, 1), (1024, 1)),
    str((16, 2, 1024, "float32", False, True)): ((1, 1), (2, 1), (1024, 1)),
    str((320, 1024, 1024, "float32", False, True)): ((2, 1), (8, 1), (1024, 1)),
    str((16, 1024, 1024, "float32", False, True)): ((1, 1), (8, 1), (1024, 1)),
    str((2048, 1024, 1024, "float32", False, True)): ((2, 1), (8, 1), (1024, 1)),
    str((2048, 4096, 1024, "float32", False, True)): ((2, 1), (8, 1), (1024, 1)),
    str((2048, 1024, 4096, "float32", False, True)): ((1, 1), (8, 1), (2048, 1)),
    str((1024, 4096, 1024, "float32", True, False)): ((4, 1), (2048, 1), (2, 1)),
    str((1024, 1024, 8, "float32", True, False)): ((16, 1), (1024, 1), (1, 1)),
    str((2, 1024, 8, "float32", True, False)): ((2, 1), (1024, 1), (8, 1)),
    str((1024, 1024, 1024, "float32", True, False)): ((4, 1), (2048, 1), (2, 1)),
    str((4096, 1024, 1024, "float32", True, False)): ((8, 1), (1024, 1), (2, 1)),
    str((1024, 1024, 160, "float32", True, False)): ((8, 1), (1024, 1), (2, 1)),
    str((4096, 1024, 2048, "float32", True, False)): ((8, 1), (1024, 1), (2, 1)),
    str((1024, 1024, 320, "float32", True, False)): ((8, 1), (1024, 1), (2, 1)),
    str((1024, 1024, 16, "float32", True, False)): ((16, 1), (1024, 1), (1, 1)),
    str((1024, 4096, 2048, "float32", True, False)): ((4, 1), (2048, 1), (2, 1)),
    str((2, 1024, 16, "float32", True, False)): ((2, 1), (1024, 1), (16, 1)),
    str((1024, 1024, 2048, "float32", True, False)): ((4, 1), (2048, 1), (2, 1)),
    str((32, 1001, 2048, "float16", False, True)): ((1, 1), (77, 1), (256, 1)),
    str((1001, 2048, 32, "float16", True, False)): ((1, 1), (2048, 1), (4, 1)),
    str((32, 2048, 1001, "float16", False, False)): ((1, 1), (2048, 1), (4, 1)),
    str((32, 1001, 2048, "float32", False, True)): ((1, 1), (7, 1), (2048, 1)),
    str((1001, 2048, 32, "float32", True, False)): ((1, 1), (2048, 1), (4, 1)),
    str((32, 2048, 1001, "float32", False, False)): ((1, 1), (2048, 1), (4, 1)),
    str((768, 3072, 131072, "float32", True, False)): ((2, 1), (3072, 1), (2, 1)),
    str((3072, 768, 131072, "float32", True, False)): ((2, 1), (768, 1), (8, 1)),
    str((65536, 768, 3072, "float32", False, True)): ((1, 1), (8, 1), (3072, 1)),
    str((131072, 3072, 768, "float32", False, True)): ((1, 1), (16, 1), (768, 1)),
    str((32768, 1024, 4096, "float32", False, True)): ((1, 1), (8, 1), (2048, 1)),
    str((65536, 1024, 4096, "float32", False, True)): ((1, 1), (8, 1), (2048, 1)),
    str((131072, 1024, 1024, "float32", False, True)): ((2, 1), (8, 1), (1024, 1)),
    str((131072, 768, 3072, "float32", False, True)): ((1, 1), (8, 1), (3072, 1)),
    str((65536, 3072, 768, "float32", False, True)): ((1, 1), (16, 1), (768, 1)),
    str((32768, 4096, 1024, "float32", False, True)): ((2, 1), (8, 1), (1024, 1)),
    str((65536, 4096, 1024, "float32", False, True)): ((2, 1), (8, 1), (1024, 1)),
    str((10240, 21128, 1024, "float32", False, True)): ((1, 1), (16, 1), (1024, 1)),
    str((21128, 1024, 20480, "float32", True, False)): ((4, 1), (1024, 1), (4, 1)),
    str((2048, 3072, 768, "float16", False, False)): ((1, 1), (768, 1), (16, 1)),
    str((2048, 768, 3072, "float16", False, False)): ((2, 1), (768, 1), (8, 1)),
    str((10240, 21128, 768, "float32", False, True)): ((1, 1), (16, 1), (768, 1)),
    str((20480, 768, 21128, "float32", False, False)): ((2, 1), (768, 1), (8, 1)),
    str((21128, 768, 20480, "float32", True, False)): ((8, 1), (768, 1), (8, 1)),
    str((20480, 21128, 768, "float32", False, True)): ((1, 1), (16, 1), (768, 1)),
    str((32, 10, 2048, "float32", False, True)): ((1, 1), (8, 1), (2048, 1)),
    str((32, 10, 2048, "float16", False, True)): ((1, 1), (8, 1), (2048, 1)),
    str((32, 10, 4096, "float16", False, True)): ((1, 1), (2, 1), (4096, 1)),
    str((768, 768, 16384, "float16", True, False)): ((64, 1), (32, 1), (8, 1)),
    str((3072, 768, 16384, "float16", True, False)): ((64, 1), (32, 1), (8, 1)),
    str((768, 3072, 16384, "float16", True, False)): ((64, 1), (32, 1), (8, 1)),
    str((32, 9216, 4096, "float32", False, False)): ((1, 1), (9216, 1), (1, 1)),
    str((128, 128, 64, 3072, "float32", True, False)): ((1, 1), (64, 1), (64, 1), (1, 1)),
    str((21128, 1024, 10240, "float32", True, False)): ((8, 1), (1024, 1), (1, 1)),
    str((128, 128, 64, 1536, "float16", True, False)): ((1, 1), (1, 1), (16, 1), (1536, 1)),
    str((3072, 768, 2048, "float16", True, False)): ((1, 1), (16, 1), (1024, 1)),
    str((768, 3072, 2048, "float16", True, False)): ((1, 1), (16, 1), (1024, 1)),
    str((4096, 1024, 65536, "float32", True, False)): ((4, 1), (1024, 1), (4, 1)),  # auto tiling crash
    str((1024, 4096, 65536, "float32", True, False)): ((1, 1), (4096, 1), (4, 1)),  # auto tiling crash
    str((16384, 3072, 768, "float16", False, True)): ((1, 1), (16, 1), (768, 1)),  # auto tiling crash
    str((16384, 768, 3072, "float16", False, True)): ((1, 1), (16, 1), (1024, 1)),  # auto tiling crash
    str((1024, 4096, 1024, "float16", False, True)): ((1, 1), (16, 1), (1024, 1)),  # auto tiling crash
    str((131072, 1024, 1024, "float32", False, False)): ((2, 1), (1024, 1), (8, 1)),

    # Alexnet shape
    str((32, 10, 4096, "float32", False, True)): ((32, 1), (10, 1), (1, 1)),  # auto tiling crash
    str((32, 4096, 4096, "float16", False, False)): ((1, 1), (16, 1), (512, 1)),  # auto tiling crash
    str((32, 9216, 4096, "float16", False, False)): ((1, 1), (16, 1), (512, 1)),  # auto tiling crash
    str((32, 4096, 9216, "float16", False, True)): ((1, 1), (16, 1), (1024, 1)),  # auto tiling crash
    str((128, 128, 64, 1536, "float16", True, False)): ((1, 1), (1, 1), (16, 1), (96, 1)),  # auto tiling crash
    str((32, 4096, 4096, "float32", False, False)): ((1, 1), (256, 1), (64, 1)),  # performance optimization
    # Alexnet shape

    str((768, 3072, 16384, "float16", True, False)): ((64, 1), (32, 1), (8, 1)),

    str((3072, 768, 4096, "float16", True, False)): ((2, 1), (768, 1), (8, 1)),
    str((768, 3072, 4096, "float16", True, False)): ((2, 1), (3072, 1), (2, 1)),

    str((400, 120, 32, "float16", True, False)): ((2, 1), (32, 1), (2, 1)),
}

# Number of cores the tiling strategy plans for: 2 when product_is_mini() is true, 32 otherwise.
if product_is_mini():
    CORE_NUM = 2
else:
    CORE_NUM = 32
# Element-count threshold used by batchmatmul_tiling_strategy when deciding how many
# cores are worth using (512 elements per core).
MINIMAL_FOR_MULTICORE = 512 * CORE_NUM


def get_best_align_elem(tensor_size, tensor_dtype):
    """
    Get the best tiling factor for an alignment axis.

    The basic alignment unit is the number of `tensor_dtype` elements that fit in
    one `ct_util.BLOCK_SIZE` block. The best factor is the greatest common divisor
    of the axis length and that unit.

    Args:
        tensor_size (int): Length of the axis to tile.
        tensor_dtype (str): Data type of the tensor, used to get the element size.

    Returns:
        int: gcd(tensor_size, alignment unit) when it is not 1, otherwise -1.
        Callers must handle the -1 sentinel ("no alignment-friendly factor").
    """
    basic_align_elem = int(ct_util.BLOCK_SIZE / get_bytes(tensor_dtype))
    gcd = greatest_common_divisor(tensor_size, basic_align_elem)
    if gcd != 1:
        return gcd
    # No LCM-based fallback: the least common multiple of two positive integers is
    # never smaller than either of them, so an "lcm < tensor_size" candidate can
    # never exist and only the sentinel remains.
    return -1


def get_shape_pos_map(tensor_shape):
    """
    Map the axes of a batchmatmul shape to their tiling axis positions.

    The shape is laid out as [(batch...,) M, N, K]. Batch axes of size 1 are
    dropped. Each remaining batch axis gets the key "b<i>" and position i.
    For the last (up to) three axes, named "m", "n" and "k", the position
    counter only advances when the axis size is not 1. A size-1 axis therefore
    shares the position of the preceding axis, or -1 if there is none.

    Args:
        tensor_shape (list|tuple): [(batch...,) M, N, K].

    Returns:
        tuple: (batch_pos, mnk). batch_pos lists the indices in `tensor_shape`
        of batch axes whose size is not 1. mnk maps "b0", "b1", ..., "m", "n",
        "k" to positions.
    """
    batch_pos = [pos for pos in range(len(tensor_shape) - 3) if tensor_shape[pos] != 1]
    axis_map = {"b%s" % idx: idx for idx in range(len(batch_pos))}
    cursor = len(batch_pos) - 1
    for axis_name, axis_len in zip(("m", "n", "k"), tensor_shape[-3:]):
        if axis_len != 1:
            cursor += 1
        axis_map[axis_name] = cursor
    return batch_pos, axis_map


def batchmatmul_tiling_strategy(shape, align_dtype, attrs):
    """
    Build a custom tiling strategy for static-shape batchmatmul.

    Constraints are created on the batch axes (to spread them over cores) and on
    the M, N and K axes (alignment and memory/compute balance). An axis priority
    and a memory-ratio setting are appended at the end.

    Args:
        shape (list): [(Batch_out, Batch_in,) M, N, K] with static integer sizes.
        align_dtype (str): Data type used to derive the alignment unit, i.e. the
                           number of elements that fit in one ct_util.BLOCK_SIZE block.
        attrs (dict): Build attributes, updated in place. "custom_tiling" is always
                      set. "enable_multicore" is set to 0 when N is smaller than the
                      alignment unit.

    Returns:
        dict: The same `attrs` object.

    Raises:
        RuntimeError: If `shape` has fewer than 3 dimensions.
    """
    if len(shape) < 3:
        raise RuntimeError("Shape must be in the form of [(Batch_out, Batch_in,) M, N, K]. "
                           "Current length of shape is {}".format(len(shape)))
    strategy = list()
    m, n, k = shape[-3:]
    # batch_pos: indices of batch axes whose size is not 1; mnk: axis name -> tiling axis position.
    batch_pos, mnk = get_shape_pos_map(shape)
    total_size = reduce(lambda x, y: int(x) * int(y), shape)

    # set minimal tile for n as the basic block size for alignment
    align_elem = int(ct_util.BLOCK_SIZE / get_bytes(align_dtype))
    # tile_n is gcd(n, align_elem), or -1 when that gcd is 1 (hence the max(1, tile_n) guards below).
    tile_n = get_best_align_elem(n, align_dtype)

    # if n is smaller than block size, it is not safe to open multi-core for now
    if n < align_elem:
        n_constraint = [ct_util.TileConstraint.FACTOR]
        # Treat every core as already used, which skips the batch-axis core split below.
        used_core = CORE_NUM
        attrs["enable_multicore"] = 0
    else:
        n_constraint = [ct_util.TileConstraint.MOD, ct_util.TileConstraint.MIN]
        used_core = 1

    if total_size >= MINIMAL_FOR_MULTICORE and used_core < CORE_NUM:
        # set maximal tile for batch according to multi-core usage
        for i, p in enumerate(batch_pos):
            # ceil(CORE_NUM / used_core) cores are still free, capped by the size of this batch axis.
            use = min(shape[p], int((CORE_NUM - 1 + used_core) / used_core))
            used_core *= use
            max_b = int(shape[p] / use)
            # Batch axis i has tiling position i (see get_shape_pos_map).
            strategy.append(ct_util.create_constraint_on_axis(values=max_b,
                                                              constraints=ct_util.TileConstraint.MAX,
                                                              band=0,
                                                              axis=i)[0])

    tile_k = get_best_align_elem(k, align_dtype)
    k_constraint = ct_util.TileConstraint.MIN
    # Divide out the number of K tiles; `/=` turns total_size into a float from here on.
    # NOTE(review): tile_k is -1 when gcd(k, align_elem) == 1. max(1, ...) guards this
    # line, but int(k / tile_k) below then evaluates to -k -- confirm this is intended.
    total_size /= max(1, int(k / tile_k))

    # set minimal tile for m according to multi-core usage when there is no expansion
    tile_m = -1
    m_constraint = ct_util.TileConstraint.MAX
    m_per_block = m
    max_core = min(CORE_NUM, int(total_size / MINIMAL_FOR_MULTICORE))
    if greatest_common_divisor(n, align_elem) != 1 and used_core < max_core:
        # Cores still available after the batch split, rounded up.
        left_core = int((max_core - 1 + used_core) / used_core)
        core_limit = max(1, int(m / greatest_common_divisor(left_core, m)))
        # Number of N tiles times number of K tiles.
        nk_in_mem = int(n / max(1, tile_n)) * int(k / tile_k)
        balance_limit = max(1, int(m / greatest_common_divisor(nk_in_mem, m)))
        tile_m = min(core_limit, balance_limit)
        m_per_block = int(m / tile_m)

    # for large m case, it is more efficient to balance memory bound and calculation bound
    if m_per_block > int(n / max(1, tile_n)) * int(k / tile_k):
        tile_m = max(min(m, align_elem), tile_m)
        k_constraint = ct_util.TileConstraint.FACTOR

    # create constraints based on previous analysis
    # Size-1 axes are skipped: get_shape_pos_map does not give them a position of their own.
    if m != 1:
        strategy.append(ct_util.create_constraint_on_axis(values=tile_m,
                                                          constraints=m_constraint,
                                                          band=0,
                                                          axis=mnk["m"])[0])
    if n != 1:
        # Either a single FACTOR constraint, or MOD plus MIN, all with value tile_n.
        for constraint in n_constraint:
            strategy.append(ct_util.create_constraint_on_axis(values=tile_n,
                                                              constraints=constraint,
                                                              band=0,
                                                              axis=mnk["n"])[0])

    if k != 1:
        strategy.append(ct_util.create_constraint_on_axis(values=tile_k,
                                                          constraints=k_constraint,
                                                          band=0,
                                                          axis=mnk["k"])[0])
    # Mark the larger of K and N (K on ties) as the higher-priority axis.
    higher_priority_pos = mnk["k"] if k >= n else mnk["n"]
    strategy.append(ct_util.create_constraint_on_axis(values=0,
                                                      constraints=ct_util.TileConstraint.SET_PRIORITY,
                                                      band=0,
                                                      axis=higher_priority_pos)[0])
    # Memory ratio of 0.7, applied as a common (not per-axis) constraint.
    strategy.append(ct_util.modify_common_constraints(0.7, ct_util.TileConstraint.SET_MEM_RATIO))
    attrs["custom_tiling"] = strategy
    return attrs

def batchmatmul_tiling_strategy_dynamic(shape, output, attrs):
    """
    Build a fixed tiling strategy for batchmatmul with dynamic shapes.

    Unlike the static strategy, no size-based analysis is done. M is tiled with
    a factor of 1, N gets a "FULL" maximum tile, and K is tiled with a factor of 8.
    A 0.7 memory ratio is added at the end.

    Args:
        shape (list): [(Batch_out, Batch_in,) M, N, K]; presumably contains
                      symbolic sizes for the dynamic axes.
        output (tvm.tensor.Tensor): Tensor handed to
                      ds.set_dynamic_shape_limit_for_tensor together with 2048 and [1].
        attrs (dict): Build attributes, updated in place with "custom_tiling"
                      and "dynamic_shape".

    Returns:
        dict: The same `attrs` object.

    Raises:
        RuntimeError: If `shape` has fewer than 3 dimensions.
    """
    if len(shape) < 3:
        raise RuntimeError("Shape must be in the form of [(Batch_out, Batch_in,) M, N, K]. "
                           "Current length of shape is {}".format(len(shape)))
    _, axis_pos = get_shape_pos_map(shape)
    # (tile value, constraint kind, axis name) for the M, N and K axes, in order.
    axis_rules = (
        (1, ct_util.TileConstraint.FACTOR, "m"),
        ("FULL", ct_util.TileConstraint.MAX, "n"),
        (8, ct_util.TileConstraint.FACTOR, "k"),
    )
    strategy = [
        ct_util.create_constraint_on_axis(values=value, constraints=kind, band=0, axis=axis_pos[name])[0]
        for value, kind, name in axis_rules
    ]
    strategy.append(ct_util.modify_common_constraints(0.7, ct_util.TileConstraint.SET_MEM_RATIO))

    attrs["custom_tiling"] = strategy
    attrs["dynamic_shape"] = ds.set_dynamic_shape_limit_for_tensor(output, 2048, [1])
    return attrs

def get_mnk_from_matrix(shape_a_list, shape_b_list, trans_a, trans_b):
    """
    Derive the [M, N, K] sizes of a (batch) matrix multiplication.

    Only the last two dimensions of each shape are inspected.

    Args:
        shape_a_list (list): Left operand shape, (..., M, K), or (..., K, M) if `trans_a`.
        shape_b_list (list): Right operand shape, (..., K, N), or (..., N, K) if `trans_b`.
        trans_a (bool): Whether the left operand is stored transposed.
        trans_b (bool): Whether the right operand is stored transposed.

    Returns:
        list: [M, N, K].
    """
    rows_a, cols_a = shape_a_list[-2], shape_a_list[-1]
    m, k = (cols_a, rows_a) if trans_a else (rows_a, cols_a)
    n = shape_b_list[-2 if trans_b else -1]
    return [m, n, k]


def batchmatmul_set_dim(a_value, b_value, trans_a, trans_b):
    """
    Look up the dim info for a batchmatmul in `batchmatmul_set_dim_map`.

    The lookup key is the string form of the tuple
    ((batch dims of a_value,) M, N, K, dtype of a_value, trans_a, trans_b).

    Args:
        a_value (tvm.tensor.Tensor): Left operand.
        b_value (tvm.tensor.Tensor): Right operand.
        trans_a (bool): Whether a_value is transposed.
        trans_b (bool): Whether b_value is transposed.

    Returns:
        tuple: (result of ct_util.set_dims_by_key for the key, the key string).
    """
    shape_a = get_shape(a_value)
    shape_b = get_shape(b_value)
    mnk = get_mnk_from_matrix(shape_a, shape_b, trans_a, trans_b)
    # All dimensions before the last two are batch dimensions (empty for 2D inputs).
    batch_dims = tuple(shape_a[:-2])
    key_str = str(batch_dims + tuple(mnk) + (a_value.dtype, trans_a, trans_b))
    return ct_util.set_dims_by_key(key_str, batchmatmul_set_dim_map), key_str


def batchmatmul_bias_set_dim(a_value, b_value, bias_value, trans_a, trans_b):
    """
    Set-dim hook for batchmatmul with bias.

    The bias is not part of the lookup key, so `bias_value` is ignored and the
    lookup is delegated to batchmatmul_set_dim.

    Returns:
        tuple: (dim info, key string) exactly as returned by batchmatmul_set_dim.
    """
    del bias_value  # unused: the bias does not influence the dim lookup key
    return batchmatmul_set_dim(a_value, b_value, trans_a, trans_b)


def batchmatmul_no_bias_set_dim(a_value, b_value, trans_a, trans_b):
    """
    Set-dim hook for batchmatmul without bias.

    Returns:
        tuple: (dim info, key string) as returned by batchmatmul_set_dim.
    """
    dim_info, key = batchmatmul_set_dim(a_value, b_value, trans_a, trans_b)
    return dim_info, key


@ct_util.reg_set_dim_func(batchmatmul_bias_set_dim)
@vc_util.check_input_type(akg.tvm.tensor.Tensor, akg.tvm.tensor.Tensor, akg.tvm.tensor.Tensor, bool, bool)
def batchmatmul_bias(a_value, b_value, bias_value, trans_a, trans_b):
    """
    Multiply two tensors in batches and add a broadcastable bias to the product.

    Args:
        a_value (tvm.tensor.Tensor): 2D, 3D or 4D float16/float32 tensor with shape (..., r_A, c_A).
        b_value (tvm.tensor.Tensor): 2D, 3D or 4D tensor with the dtype of a_value and shape (..., r_B, c_B).
        bias_value (tvm.tensor.Tensor): Bias with the dtype of a_value. It must be
                                        broadcastable to the shape of the product.
        trans_a (bool): Whether a_value is transposed.
        trans_b (bool): Whether b_value is transposed.

    Returns:
        tuple: (output, attrs). output is a tvm.tensor.Tensor with the dtype of a_value
        and shape (..., r_C, c_C), where r_C = c_A if trans_a else r_A and
        c_C = r_B if trans_b else c_B. attrs is the dict of build attributes
        (compute-in-place flag, dim info and custom tiling).

    Raises:
        TypeError: If trans_a or trans_b is not a bool.
        ValueError: If a_value is not 2D, 3D or 4D.
    """
    if not isinstance(trans_a, bool):
        raise TypeError("trans_a should be of type Boolean.")
    if not isinstance(trans_b, bool):
        raise TypeError("trans_b should be of type Boolean.")
    vc_util.ops_dtype_check([a_value.dtype, b_value.dtype, bias_value.dtype], vc_util.DtypeForDavinci.ALL_FLOAT)
    # b_value and the bias must both share the dtype of a_value.
    for other_dtype in (b_value.dtype, bias_value.dtype):
        vc_util.elemwise_dtype_check(a_value.dtype, other_dtype)
    vc_util.gemm_format_check(get_shape(a_value), get_shape(b_value), trans_a, trans_b)
    if len(a_value.shape) not in [2, 3, 4]:
        raise ValueError("Batch matmul only support 2D, 3D and 4D now.")

    # batchmatmul returns (tensor, attrs); only the tensor is needed here.
    matmul_out = batchmatmul(a_value, b_value, trans_a, trans_b)
    c_value = matmul_out[0] if isinstance(matmul_out, (tuple, list)) else matmul_out

    vc_util.auto_broadcast_check(get_shape(bias_value), get_shape(c_value))
    out_shape = get_shape(c_value)
    bias_value = akg.topi.broadcast_to(bias_value, out_shape)

    set_dim_result = batchmatmul_bias_set_dim(a_value, b_value, bias_value, trans_a, trans_b)
    dim_info = set_dim_result[0] if isinstance(set_dim_result, (tuple, list)) else set_dim_result
    attrs = {"enable_compute_in_place": True}
    if dim_info != "":
        attrs["dim"] = dim_info

    batch = get_shape(a_value)[:-2]
    mnk = get_mnk_from_matrix(get_shape(a_value), get_shape(b_value), trans_a, trans_b)
    attrs = batchmatmul_tiling_strategy(batch + mnk, c_value.dtype, attrs)

    output = akg.tvm.compute(bias_value.shape,
                             lambda *idx: c_value(*idx) + bias_value(*idx),
                             name='matmul_bias_output')
    return output, attrs


@ct_util.reg_set_dim_func(batchmatmul_no_bias_set_dim)
@vc_util.check_input_type(akg.tvm.tensor.Tensor, akg.tvm.tensor.Tensor, bool, bool)
def batchmatmul(a_value, b_value, trans_a=False, trans_b=False):
    """
    Multiply two tensors in batches.

    float16 inputs go through the *_cast implementations with "float32" as the
    intermediate dtype. Other dtypes use the plain implementations.

    Args:
        a_value (tvm.tensor.Tensor): 2D, 3D or 4D float16/float32 tensor with shape (..., r_A, c_A).
        b_value (tvm.tensor.Tensor): 2D, 3D or 4D tensor with the dtype of a_value and shape (..., r_B, c_B).
        trans_a (bool): Whether a_value is transposed. Defaults to False.
        trans_b (bool): Whether b_value is transposed. Defaults to False.

    Returns:
        tuple: (output, attrs). output is a tvm.tensor.Tensor with shape (..., r_C, c_C),
        where r_C = c_A if trans_a else r_A and c_C = r_B if trans_b else c_B.
        attrs is the dict of build attributes; dynamic shapes get the dynamic tiling
        strategy plus extra pass switches.

    Raises:
        TypeError: If trans_a or trans_b is not a bool.
        ValueError: If a_value is not 2D, 3D or 4D.
    """
    if not isinstance(trans_a, bool):
        raise TypeError("trans_a should be of type Boolean.")
    if not isinstance(trans_b, bool):
        raise TypeError("trans_b should be of type Boolean.")
    vc_util.ops_dtype_check([a_value.dtype, b_value.dtype], vc_util.DtypeForDavinci.ALL_FLOAT)
    vc_util.elemwise_dtype_check(a_value.dtype, b_value.dtype)
    vc_util.gemm_format_check(get_shape(a_value), get_shape(b_value), trans_a, trans_b)
    rank = len(a_value.shape)
    if rank not in [2, 3, 4]:
        raise ValueError("Batch matmul only support 2D, 3D and 4D now.")

    # Select the implementation by input rank (guaranteed to be 2, 3 or 4 above).
    if a_value.dtype == 'float16':
        cast_impls = {2: vectormatmul_2d_cast, 3: vectormatmul_3d_cast, 4: vectormatmul_4d_cast}
        c_value = cast_impls[rank](a_value, b_value, trans_a, trans_b, "float32")
    else:
        plain_impls = {2: vectormatmul_2d, 3: vectormatmul_3d, 4: vectormatmul_4d}
        c_value = plain_impls[rank](a_value, b_value, trans_a, trans_b)

    set_dim_result = batchmatmul_no_bias_set_dim(a_value, b_value, trans_a, trans_b)
    dim_info = set_dim_result[0] if isinstance(set_dim_result, (tuple, list)) else set_dim_result
    attrs = {"enable_compute_in_place": True}
    if dim_info != "":
        attrs["dim"] = dim_info

    mnk = get_mnk_from_matrix(get_shape(a_value), get_shape(b_value), trans_a, trans_b)
    batch = get_shape(a_value)[:-2]
    if ds.shape_is_dynamic([a_value, b_value]):
        attrs = batchmatmul_tiling_strategy_dynamic(batch + mnk, c_value, attrs)
        attrs.update({
            "enable_pre_storage_write_simplify": True,
            "enable_sink_allocate": True,
            "enable_double_buffer": False,
        })
    else:
        attrs = batchmatmul_tiling_strategy(batch + mnk, c_value.dtype, attrs)

    return c_value, attrs


def vectormatmul_3d(a_value, b_value, trans_a, trans_b):
    """
    Hybrid-script implementation of 3D batchmatmul without any dtype cast.

    Each variant first stores every product a * b in a local buffer t_1, then
    sums those products over K into t_2, starting from `zero`.

    Args:
        a_value (tvm.tensor.Tensor): (BS, M, K), or (BS, K, M) if `trans_a`.
        b_value (tvm.tensor.Tensor): (BS, K, N), or (BS, N, K) if `trans_b`.
        trans_a (bool): Whether a_value is transposed.
        trans_b (bool): Whether b_value is transposed.

    Returns:
        tvm.tensor.Tensor: (BS, M, N) result with the dtype of a_value.

    Raises:
        ValueError: If both `trans_a` and `trans_b` are True.
    """
    if trans_a:
        bs, k, m = a_value.shape
    else:
        bs, m, k = a_value.shape
    if trans_b:
        n = b_value.shape[-2]
    else:
        n = b_value.shape[-1]

    dtype = a_value.dtype
    # Initial value of the accumulator, in the input dtype.
    zero = akg.tvm.const(0.0, dtype=dtype)

    # No transpose: t_1[bs, m, k, n] = a[bs, m, k] * b[bs, k, n], then summed over k.
    # bs, m, k and n are read from the enclosing scope via capture=locals().
    @script(capture=locals())
    def matmul_hybrid_f_f(a, b, zero):
        t_1 = allocate((bs, m, k, n), a.dtype, 'local')
        t_2 = allocate((bs, m, n), a.dtype, 'local')
        for i_bs in range(0, bs):
            for i_m in range(0, m):
                for i_k in range(0, k):
                    for i_n in range(0, n):
                        t_1[i_bs, i_m, i_k, i_n] = a[i_bs, i_m, i_k] * b[i_bs, i_k, i_n]
                for i1_n in range(0, n):
                    t_2[i_bs, i_m, i1_n] = zero
                for i1_k in range(0, k):
                    for i1_n in range(0, n):
                        t_2[i_bs, i_m, i1_n] = t_2[i_bs, i_m, i1_n] + t_1[i_bs, i_m, i1_k, i1_n]
        return t_2

    # b transposed: t_1[bs, m, n, k] = a[bs, m, k] * b[bs, n, k], accumulated in the same k loop.
    @script(capture=locals())
    def matmul_hybrid_f_t(a, b, zero):
        t_1 = allocate((bs, m, n, k), a.dtype, 'local')
        t_2 = allocate((bs, m, n), a.dtype, 'local')
        for i_bs in range(0, bs):
            for i_m in range(0, m):
                for i_n in range(0, n):
                    t_2[i_bs, i_m, i_n] = zero
                    for i_k in range(0, k):
                        t_1[i_bs, i_m, i_n, i_k] = a[i_bs, i_m, i_k] * b[i_bs, i_n, i_k]
                        t_2[i_bs, i_m, i_n] = t_1[i_bs, i_m, i_n, i_k] + t_2[i_bs, i_m, i_n]
        return t_2

    # a transposed: t_1[bs, m, k, n] = a[bs, k, m] * b[bs, k, n], then summed over k.
    @script(capture=locals())
    def matmul_hybrid_t_f(a, b, zero):
        t_1 = allocate((bs, m, k, n), a.dtype, 'local')
        t_2 = allocate((bs, m, n), a.dtype, 'local')
        for i_bs in range(0, bs):
            for i_m in range(0, m):
                for i_k in range(0, k):
                    for i_n in range(0, n):
                        t_1[i_bs, i_m, i_k, i_n] = a[i_bs, i_k, i_m] * b[i_bs, i_k, i_n]
                for i1_n in range(0, n):
                    t_2[i_bs, i_m, i1_n] = zero
                for i1_k in range(0, k):
                    for i1_n in range(0, n):
                        t_2[i_bs, i_m, i1_n] = t_2[i_bs, i_m, i1_n] + t_1[i_bs, i_m, i1_k, i1_n]
        return t_2

    # Dispatch on the transpose flags; transposing both operands is rejected.
    if not trans_a and not trans_b:
        c_value = matmul_hybrid_f_f(a_value, b_value, zero)
    elif not trans_a and trans_b:
        c_value = matmul_hybrid_f_t(a_value, b_value, zero)
    elif trans_a and not trans_b:
        c_value = matmul_hybrid_t_f(a_value, b_value, zero)
    else:
        raise ValueError('Not support both transpose yet')

    return c_value


def vectormatmul_4d(a_value, b_value, trans_a, trans_b):
    """
    Hybrid-script implementation of 4D batchmatmul without any dtype cast.

    Each variant first stores every product a * b in a local buffer t_1, then
    sums those products over K into t_2, starting from `zero`.

    Args:
        a_value (tvm.tensor.Tensor): (BS1, BS2, M, K), or (BS1, BS2, K, M) if `trans_a`.
        b_value (tvm.tensor.Tensor): (BS1, BS2, K, N), or (BS1, BS2, N, K) if `trans_b`.
        trans_a (bool): Whether a_value is transposed.
        trans_b (bool): Whether b_value is transposed.

    Returns:
        tvm.tensor.Tensor: (BS1, BS2, M, N) result with the dtype of a_value.

    Raises:
        ValueError: If both `trans_a` and `trans_b` are True.
    """
    if trans_a:
        bs1, bs2, k, m = a_value.shape
    else:
        bs1, bs2, m, k = a_value.shape
    if trans_b:
        n = b_value.shape[-2]
    else:
        n = b_value.shape[-1]

    dtype = a_value.dtype
    # Initial value of the accumulator, in the input dtype.
    zero = akg.tvm.const(0.0, dtype=dtype)

    # No transpose: t_1[.., m, k, n] = a[.., m, k] * b[.., k, n], then summed over k.
    # bs1, bs2, m, k and n are read from the enclosing scope via capture=locals().
    @script(capture=locals())
    def matmul_hybrid_f_f(a, b, zero):
        t_1 = allocate((bs1, bs2, m, k, n), a.dtype, 'local')
        t_2 = allocate((bs1, bs2, m, n), a.dtype, 'local')
        for i_bs1 in range(0, bs1):
            for i_bs2 in range(0, bs2):
                for i_m in range(0, m):
                    for i_k in range(0, k):
                        for i_n in range(0, n):
                            t_1[i_bs1, i_bs2, i_m, i_k, i_n] = a[i_bs1, i_bs2, i_m, i_k] * b[i_bs1, i_bs2, i_k, i_n]
                    for i1_n in range(0, n):
                        t_2[i_bs1, i_bs2, i_m, i1_n] = zero
                    for i1_k in range(0, k):
                        for i1_n in range(0, n):
                            t_2[i_bs1, i_bs2, i_m, i1_n] = t_2[i_bs1, i_bs2, i_m, i1_n] + \
                                t_1[i_bs1, i_bs2, i_m, i1_k, i1_n]
        return t_2

    # b transposed: t_1[.., m, n, k] = a[.., m, k] * b[.., n, k], accumulated in the same k loop.
    @script(capture=locals())
    def matmul_hybrid_f_t(a, b, zero):
        t_1 = allocate((bs1, bs2, m, n, k), a.dtype, 'local')
        t_2 = allocate((bs1, bs2, m, n), a.dtype, 'local')
        for i_bs1 in range(0, bs1):
            for i_bs2 in range(0, bs2):
                for i_m in range(0, m):
                    for i_n in range(0, n):
                        t_2[i_bs1, i_bs2, i_m, i_n] = zero
                        for i_k in range(0, k):
                            t_1[i_bs1, i_bs2, i_m, i_n, i_k] = a[i_bs1, i_bs2, i_m, i_k] * b[i_bs1, i_bs2, i_n, i_k]
                            t_2[i_bs1, i_bs2, i_m, i_n] = t_1[i_bs1, i_bs2, i_m, i_n, i_k] + t_2[i_bs1, i_bs2, i_m, i_n]
        return t_2

    # a transposed: t_1[.., m, k, n] = a[.., k, m] * b[.., k, n], then summed over k.
    @script(capture=locals())
    def matmul_hybrid_t_f(a, b, zero):
        t_1 = allocate((bs1, bs2, m, k, n), a.dtype, 'local')
        t_2 = allocate((bs1, bs2, m, n), a.dtype, 'local')
        for i_bs1 in range(0, bs1):
            for i_bs2 in range(0, bs2):
                for i_m in range(0, m):
                    for i_k in range(0, k):
                        for i_n in range(0, n):
                            t_1[i_bs1, i_bs2, i_m, i_k, i_n] = a[i_bs1, i_bs2, i_k, i_m] * b[i_bs1, i_bs2, i_k, i_n]
                    for i1_n in range(0, n):
                        t_2[i_bs1, i_bs2, i_m, i1_n] = zero
                    for i1_k in range(0, k):
                        for i1_n in range(0, n):
                            t_2[i_bs1, i_bs2, i_m, i1_n] = t_2[i_bs1, i_bs2, i_m, i1_n] + \
                                t_1[i_bs1, i_bs2, i_m, i1_k, i1_n]
        return t_2

    # Dispatch on the transpose flags; transposing both operands is rejected.
    if not trans_a and not trans_b:
        c_value = matmul_hybrid_f_f(a_value, b_value, zero)
    elif not trans_a and trans_b:
        c_value = matmul_hybrid_f_t(a_value, b_value, zero)
    elif trans_a and not trans_b:
        c_value = matmul_hybrid_t_f(a_value, b_value, zero)
    else:
        raise ValueError('Not support both transpose yet')

    return c_value


def vectormatmul_4d_cast(a_value, b_value, trans_a, trans_b, cast_dtype):
    """
    DSL implementation of 4D batchmatmul that multiplies and accumulates in `cast_dtype`.

    Each input element is cast to `cast_dtype` before the multiplication. The
    products are summed over K, and the result is cast back to the dtype of a_value.

    Args:
        a_value (tvm.tensor.Tensor): (B1, B2, M, K), or (B1, B2, K, M) if `trans_a`.
        b_value (tvm.tensor.Tensor): (B1, B2, K, N), or (B1, B2, N, K) if `trans_b`.
        trans_a (bool): Whether a_value is transposed.
        trans_b (bool): Whether b_value is transposed.
        cast_dtype (str): Data type of the intermediate products and of the reduction.

    Returns:
        tvm.tensor.Tensor: (B1, B2, M, N) result with the dtype of a_value, for every
        combination of transpose flags.
    """
    if trans_a:
        b1, b2, k, m = a_value.shape
    else:
        b1, b2, m, k = a_value.shape
    if trans_b:
        n = b_value.shape[-2]
    else:
        n = b_value.shape[-1]

    dtype = a_value.dtype

    def matmul_4d_dsl(a_value, b_value, trans_a, trans_b):
        # ele_mul[b1, b2, m, n, k] = A[m, k] * B[k, n], where A / B are the logical
        # (un-transposed) operands: the index order below undoes each transpose flag.
        if not trans_a and not trans_b:
            ele_mul = akg.tvm.compute((b1, b2, m, n, k),
                                      lambda i_b1, i_b2, i_m, i_n, i_k:
                                      a_value[i_b1, i_b2, i_m, i_k].astype(cast_dtype) *
                                      b_value[i_b1, i_b2, i_k, i_n].astype(cast_dtype),
                                      name="ele_mul")
        elif not trans_a and trans_b:
            ele_mul = akg.tvm.compute((b1, b2, m, n, k),
                                      lambda i_b1, i_b2, i_m, i_n, i_k:
                                      a_value[i_b1, i_b2, i_m, i_k].astype(cast_dtype) *
                                      b_value[i_b1, i_b2, i_n, i_k].astype(cast_dtype),
                                      name="ele_mul")
        elif trans_a and not trans_b:
            ele_mul = akg.tvm.compute((b1, b2, m, n, k),
                                      lambda i_b1, i_b2, i_m, i_n, i_k:
                                      a_value[i_b1, i_b2, i_k, i_m].astype(cast_dtype) *
                                      b_value[i_b1, i_b2, i_k, i_n].astype(cast_dtype),
                                      name="ele_mul")
        else:
            ele_mul = akg.tvm.compute((b1, b2, m, n, k),
                                      lambda i_b1, i_b2, i_m, i_n, i_k:
                                      b_value[i_b1, i_b2, i_n, i_k].astype(cast_dtype) *
                                      a_value[i_b1, i_b2, i_k, i_m].astype(cast_dtype),
                                      name="ele_mul")
        # Sum the products over the trailing K axis.
        reduce_axis = akg.tvm.reduce_axis((0, k), name='reduce_axis')
        output_shape = (b1, b2, m, n)
        m_c = akg.tvm.compute(output_shape, lambda *i: akg.tvm.sum(ele_mul[i + (reduce_axis,)],
                                                                   axis=reduce_axis), name="matmul_compute")
        return m_c

    c_cast = matmul_4d_dsl(a_value, b_value, trans_a, trans_b)
    # No final transpose is needed, even when both flags are set: the indexing above
    # already yields the (B1, B2, M, N) product. A 2-axis permutation such as (1, 0) is
    # also not a valid transpose of this 4D tensor.
    return cast.cast(c_cast, dtype)


def vectormatmul_3d_cast(a_value, b_value, trans_a, trans_b, cast_dtype):
    """
    DSL implementation of 3D batchmatmul that multiplies and accumulates in `cast_dtype`.

    Each input element is cast to `cast_dtype` before the multiplication. The
    products are summed over K, and the result is cast back to the dtype of a_value.

    Args:
        a_value (tvm.tensor.Tensor): (B, M, K), or (B, K, M) if `trans_a`.
        b_value (tvm.tensor.Tensor): (B, K, N), or (B, N, K) if `trans_b`.
        trans_a (bool): Whether a_value is transposed.
        trans_b (bool): Whether b_value is transposed.
        cast_dtype (str): Data type of the intermediate products and of the reduction.

    Returns:
        tvm.tensor.Tensor: (B, M, N) result with the dtype of a_value, for every
        combination of transpose flags.
    """
    if trans_a:
        b, k, m = a_value.shape
    else:
        b, m, k = a_value.shape
    if trans_b:
        n = b_value.shape[-2]
    else:
        n = b_value.shape[-1]

    dtype = a_value.dtype

    def matmul_3d_dsl(a_value, b_value, trans_a, trans_b):
        # ele_mul[b, m, n, k] = A[m, k] * B[k, n], where A / B are the logical
        # (un-transposed) operands: the index order below undoes each transpose flag.
        if not trans_a and not trans_b:
            ele_mul = akg.tvm.compute((b, m, n, k), lambda i_b, i_m, i_n, i_k:
                                      a_value[i_b, i_m, i_k].astype(cast_dtype) *
                                      b_value[i_b, i_k, i_n].astype(cast_dtype),
                                      name="ele_mul")
        elif not trans_a and trans_b:
            ele_mul = akg.tvm.compute((b, m, n, k), lambda i_b, i_m, i_n, i_k:
                                      a_value[i_b, i_m, i_k].astype(cast_dtype) *
                                      b_value[i_b, i_n, i_k].astype(cast_dtype),
                                      name="ele_mul")
        elif trans_a and not trans_b:
            ele_mul = akg.tvm.compute((b, m, n, k), lambda i_b, i_m, i_n, i_k:
                                      a_value[i_b, i_k, i_m].astype(cast_dtype) *
                                      b_value[i_b, i_k, i_n].astype(cast_dtype),
                                      name="ele_mul")
        else:
            ele_mul = akg.tvm.compute((b, m, n, k), lambda i_b, i_m, i_n, i_k:
                                      b_value[i_b, i_n, i_k].astype(cast_dtype) *
                                      a_value[i_b, i_k, i_m].astype(cast_dtype),
                                      name="ele_mul")
        # Sum the products over the trailing K axis.
        reduce_axis = akg.tvm.reduce_axis((0, k), name='reduce_axis')
        output_shape = (b, m, n)
        m_c = akg.tvm.compute(output_shape, lambda *i: akg.tvm.sum(ele_mul[i + (reduce_axis,)], axis=reduce_axis),
                              name="matmul_compute")
        return m_c

    c_cast = matmul_3d_dsl(a_value, b_value, trans_a, trans_b)
    # No final transpose is needed, even when both flags are set: the indexing above
    # already yields the (B, M, N) product. A 2-axis permutation such as (1, 0) is also
    # not a valid transpose of this 3D tensor.
    return cast.cast(c_cast, dtype)


def vectormatmul_2d_cast(a_value, b_value, trans_a, trans_b, cast_dtype):
    """
    Compute a 2D matmul, casting every operand element to `cast_dtype`.

    Each element of both operands is cast to `cast_dtype` before multiplication.
    The products are summed over k (so the accumulation happens in `cast_dtype`),
    and the reduced result is cast back to the dtype of `a_value`.

    Args:
        a_value: 2D tensor laid out as (m, k), or (k, m) when `trans_a` is True.
        b_value: 2D tensor laid out as (k, n), or (n, k) when `trans_b` is True.
        trans_a (bool): whether `a_value` is stored transposed.
        trans_b (bool): whether `b_value` is stored transposed.
        cast_dtype (str): dtype used for the element products and the reduction.

    Returns:
        Tensor of shape (m, n) with the dtype of `a_value`, holding
        op(a_value) @ op(b_value), where op transposes its operand when the
        matching flag is set.
    """
    if trans_a:
        k, m = a_value.shape
    else:
        m, k = a_value.shape
    if trans_b:
        n = b_value.shape[-2]
    else:
        n = b_value.shape[-1]

    dtype = a_value.dtype

    def a_elem(i_m, i_k):
        # Element op(a)[i_m, i_k], cast to cast_dtype.
        if trans_a:
            return a_value[i_k, i_m].astype(cast_dtype)
        return a_value[i_m, i_k].astype(cast_dtype)

    def b_elem(i_k, i_n):
        # Element op(b)[i_k, i_n], cast to cast_dtype.
        if trans_b:
            return b_value[i_n, i_k].astype(cast_dtype)
        return b_value[i_k, i_n].astype(cast_dtype)

    # When float16 is cast to float32 on the whole tensor, the AutoPoly pass takes
    # a long time, so the cast is applied to each element instead.
    ele_mul = akg.tvm.compute((m, n, k),
                              lambda i_m, i_n, i_k: a_elem(i_m, i_k) * b_elem(i_k, i_n),
                              name="ele_mul")
    reduce_axis = akg.tvm.reduce_axis((0, k), name='reduce_axis')
    c_cast = akg.tvm.compute((m, n),
                             lambda i_m, i_n: akg.tvm.sum(ele_mul[i_m, i_n, reduce_axis], axis=reduce_axis),
                             name="matmul_compute")

    # ele_mul is indexed (m, n, k) for every layout, so the reduction is already
    # op(a) @ op(b) with shape (m, n). No trailing transpose is needed, not even
    # when both inputs are transposed. The previous code transposed here and
    # returned an (n, m) result.
    return cast.cast(c_cast, dtype)


def vectormatmul_2d(a_value, b_value, trans_a, trans_b):
    """
    Hybrid-script implementation of a 2D matmul (no dtype cast).

    One hybrid kernel is selected per (trans_a, trans_b) combination. Products
    and sums are computed in the dtype of `a_value`.

    Args:
        a_value: 2D tensor laid out as (m, k), or (k, m) when `trans_a` is True.
        b_value: 2D tensor laid out as (k, n), or (n, k) when `trans_b` is True.
        trans_a (bool): whether `a_value` is stored transposed.
        trans_b (bool): whether `b_value` is stored transposed.

    Returns:
        Tensor of shape (m, n) holding op(a_value) @ op(b_value), where op
        transposes its operand when the matching flag is set.
    """
    if trans_a:
        k, m = a_value.shape
    else:
        m, k = a_value.shape
    if trans_b:
        n = b_value.shape[-2]
    else:
        n = b_value.shape[-1]

    dtype = a_value.dtype
    # Initial value for the accumulators inside the hybrid kernels, typed like the inputs.
    zero = akg.tvm.const(0.0, dtype=dtype)

    # NOTE(review): the kernels below are TVM hybrid scripts that are re-parsed
    # from source. Loop variables are deliberately given distinct names
    # (i1_n, i1_k, i2_m, ...). Presumably the hybrid parser rejects reusing a
    # name, so keep them distinct when editing.

    # a: (m, k), b: (k, n). Stores every product a[m, k] * b[k, n] in a local
    # (m, k, n) buffer, then sums over k into the (m, n) output tensor.
    @script(capture=locals())
    def matmul_hybrid_f_f(a, b, zero, mv, nv, kv):
        t_1 = allocate((mv, kv, nv), a.dtype, 'local')
        t_2 = output_tensor((mv, nv), a.dtype)
        for i_m in range(0, mv):
            for i_k in range(0, kv):
                for i_n in range(0, nv):
                    t_1[i_m, i_k, i_n] = a[i_m, i_k] * b[i_k, i_n]
            for i1_n in range(0, nv):
                t_2[i_m, i1_n] = zero
            for i1_k in range(0, kv):
                for i1_n in range(0, nv):
                    t_2[i_m, i1_n] = t_2[i_m, i1_n] + t_1[i_m, i1_k, i1_n]
        return t_2

    # a: (m, k), b: (n, k). The product is stored in t_1 and added to t_2
    # inside the same k loop.
    # NOTE(review): unlike matmul_hybrid_f_f, the returned t_2 is created with
    # allocate(..., 'local') rather than output_tensor. Confirm this is intended.
    @script(capture=locals())
    def matmul_hybrid_f_t(a, b, zero, mv, nv, kv):
        t_1 = allocate((mv, nv, kv), a.dtype, 'local')
        t_2 = allocate((mv, nv), a.dtype, 'local')
        for i_m in range(0, mv):
            for i_n in range(0, nv):
                t_2[i_m, i_n] = zero
                for i_k in range(0, kv):
                    t_1[i_m, i_n, i_k] = a[i_m, i_k] * b[i_n, i_k]
                    t_2[i_m, i_n] = t_1[i_m, i_n, i_k] + t_2[i_m, i_n]
        return t_2

    # a: (k, m), b: (k, n). Same loop structure as matmul_hybrid_f_f, with a
    # read transposed.
    @script(capture=locals())
    def matmul_hybrid_t_f(a, b, zero, mv, nv, kv):
        t_1 = allocate((mv, kv, nv), a.dtype, 'local')
        t_2 = allocate((mv, nv), a.dtype, 'local')
        for i_m in range(0, mv):
            for i_k in range(0, kv):
                for i_n in range(0, nv):
                    t_1[i_m, i_k, i_n] = a[i_k, i_m] * b[i_k, i_n]
            for i1_n in range(0, nv):
                t_2[i_m, i1_n] = zero
            for i1_k in range(0, kv):
                for i1_n in range(0, nv):
                    t_2[i_m, i1_n] = t_2[i_m, i1_n] + t_1[i_m, i1_k, i1_n]
        return t_2

    # a: (k, m), b: (n, k). Produces the TRANSPOSED result:
    # t_2[n, m] = sum_k b[n, k] * a[k, m], i.e. (a^T @ b^T)^T with shape (n, m).
    # The caller transposes it back to (m, n).
    @script(capture=locals())
    def matmul_hybrid_t_t(a, b, zero, mv, nv, kv):
        t_1 = allocate((nv, kv, mv), a.dtype, 'local')
        t_2 = allocate((nv, mv), a.dtype, 'local')
        for i_n in range(0, nv):
            for i_m in range(0, mv):
                for i_k in range(0, kv):
                    t_1[i_n, i_k, i_m] = b[i_n, i_k] * a[i_k, i_m]
            for i1_m in range(0, mv):
                t_2[i_n, i1_m] = zero
            for i1_k in range(0, kv):
                for i2_m in range(0, mv):
                    t_2[i_n, i2_m] = t_2[i_n, i2_m] + t_1[i_n, i1_k, i2_m]
        return t_2

    # Dispatch on the layout flags. Every kernel takes (m, n, k) as (mv, nv, kv).
    if not trans_a and not trans_b:
        c_value = matmul_hybrid_f_f(a_value, b_value, zero, m, n, k)
    elif not trans_a and trans_b:
        c_value = matmul_hybrid_f_t(a_value, b_value, zero, m, n, k)
    elif trans_a and not trans_b:
        c_value = matmul_hybrid_t_f(a_value, b_value, zero, m, n, k)
    else:
        c1 = matmul_hybrid_t_t(a_value, b_value, zero, m, n, k)
        # The t_t kernel yields an (n, m) result, so transpose it back to (m, n).
        c_value = akg.topi.transpose(c1, (1, 0))

    return c_value
